{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "name": "Keras Flowers transfer learning (playground).ipynb",
   "provenance": [],
   "collapsed_sections": [],
   "toc_visible": true
  },
  "environment": {
   "name": "tf22-cpu.2-2.m47",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/tf22-cpu.2-2:m47"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YOykQi5n4e5u"
   },
   "source": [
    "Training on GPU will be fine for transfer learning as it is not a very demanding process."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "89B27-TGiDNB"
   },
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "9u3d4Z7uQsmp",
    "outputId": "d62575df-5c93-47ae-c9e9-fbf675efee45"
   },
   "source": [
    "import os, sys, math\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "import tensorflow as tf\n",
    "print(\"Tensorflow version \" + tf.__version__)\n",
    "AUTOTUNE = tf.data.AUTOTUNE"
   ],
   "execution_count": 1,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Tensorflow version 2.6.0\n"
     ]
    }
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "w9S3uKC_iXY5"
   },
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "M3G-2aUBQJ-H"
   },
   "source": [
    "GCS_PATTERN = 'gs://flowers-public/tfrecords-jpeg-192x192-2/*.tfrec'\n",
    "IMAGE_SIZE = [192, 192]\n",
    "\n",
    "BATCH_SIZE = 64 # 128 works on GPU too but comes very close to the memory limit of the Colab GPU\n",
    "EPOCHS = 5\n",
    "\n",
    "VALIDATION_SPLIT = 0.19\n",
    "CLASSES = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'] # do not change, maps to the labels in the data (folder names)\n",
    "NUM_IMAGES = 3670 # image count used to size the epochs below (assumed to be the total number of images matched by GCS_PATTERN)\n",
    "\n",
    "# splitting data files between training and validation\n",
    "filenames = tf.io.gfile.glob(GCS_PATTERN)\n",
    "if not filenames:\n",
    "  # fail early with a readable message instead of a ZeroDivisionError further down\n",
    "  raise ValueError(\"No data files match the pattern \" + GCS_PATTERN)\n",
    "split = int(len(filenames) * VALIDATION_SPLIT)\n",
    "training_filenames = filenames[split:]\n",
    "validation_filenames = filenames[:split]\n",
    "print(\"Pattern matches {} data files. Splitting dataset into {} training files and {} validation files\".format(len(filenames), len(training_filenames), len(validation_filenames)))\n",
    "\n",
    "# the step counts spread NUM_IMAGES evenly over the files and only count full batches\n",
    "images_per_file = NUM_IMAGES // len(filenames)\n",
    "validation_steps = (images_per_file * len(validation_filenames)) // BATCH_SIZE\n",
    "steps_per_epoch = (images_per_file * len(training_filenames)) // BATCH_SIZE\n",
    "print(\"With a batch size of {}, there will be {} batches per training epoch and {} batch(es) per validation run.\".format(BATCH_SIZE, steps_per_epoch, validation_steps))"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "cellView": "form",
    "id": "MPkvHdAYNt9J"
   },
   "source": [
    "#@title display utilities [RUN ME]\n",
    "\n",
    "def dataset_to_numpy_util(dataset, N):\n",
    "  \"\"\"Returns the first N elements of `dataset` as numpy arrays.\n",
    "\n",
    "  Args:\n",
    "    dataset: a dataset yielding (image, label) pairs; it is batched here,\n",
    "      so it is expected to be unbatched.\n",
    "    N: number of elements to group into the single batch that is read.\n",
    "\n",
    "  Returns:\n",
    "    A (numpy_images, numpy_labels) tuple holding the first batch.\n",
    "\n",
    "  Raises:\n",
    "    ValueError: if the dataset yields no elements.\n",
    "  \"\"\"\n",
    "  # only the first batch is needed: pull it directly instead of looping and breaking\n",
    "  first_batch = next(iter(dataset.batch(N)), None)\n",
    "  if first_batch is None:\n",
    "    raise ValueError(\"dataset_to_numpy_util: the dataset is empty\")\n",
    "  images, labels = first_batch\n",
    "  return images.numpy(), labels.numpy()\n",
    "\n",
    "def title_from_label_and_target(label, correct_label):\n",
    "  \"\"\"Builds the plot title for one prediction.\n",
    "\n",
    "  Args:\n",
    "    label: predicted class index (an index into CLASSES).\n",
    "    correct_label: expected class index (an index into CLASSES).\n",
    "\n",
    "  Returns:\n",
    "    A (title, correct) tuple. The title shows the predicted class name and\n",
    "    whether it is correct; for a wrong prediction it also names the expected class.\n",
    "  \"\"\"\n",
    "  correct = (label == correct_label)\n",
    "  # the expected class is only spelled out when the prediction is wrong\n",
    "  expected = '' if correct else ', should be ' + CLASSES[correct_label]\n",
    "  return \"{} [{}{}]\".format(CLASSES[label], str(correct), expected), correct\n",
    "\n",
    "def display_one_flower(image, title, subplot, red=False):\n",
    "  \"\"\"Draws one image, without axes, in the given subplot position.\n",
    "\n",
    "  Args:\n",
    "    image: the image to show, passed to plt.imshow.\n",
    "    title: text displayed above the image.\n",
    "    subplot: subplot position passed to plt.subplot (e.g. 331).\n",
    "    red: if True the title is drawn in red, otherwise in black.\n",
    "\n",
    "  Returns:\n",
    "    The next subplot position (subplot + 1).\n",
    "  \"\"\"\n",
    "  title_color = 'red' if red else 'black'\n",
    "  plt.subplot(subplot)\n",
    "  plt.axis('off')\n",
    "  plt.imshow(image)\n",
    "  plt.title(title, fontsize=16, color=title_color)\n",
    "  return subplot + 1\n",
    "  \n",
    "def display_9_images_from_dataset(dataset):\n",
    "  \"\"\"Shows up to 9 images from `dataset` in a 3x3 grid, titled with their class names.\n",
    "\n",
    "  Args:\n",
    "    dataset: a dataset yielding (image, label) pairs, as accepted by dataset_to_numpy_util.\n",
    "  \"\"\"\n",
    "  plt.figure(figsize=(13,13))\n",
    "  images, labels = dataset_to_numpy_util(dataset, 9)\n",
    "  position = 331 # 3 rows, 3 columns, first cell\n",
    "  for index, image in enumerate(images[:9]):\n",
    "    position = display_one_flower(image, CLASSES[labels[index]], position)\n",
    "\n",
    "  plt.subplots_adjust(wspace=0.1, hspace=0.1)\n",
    "  plt.show()\n",
    "  \n",
    "def display_9_images_with_predictions(images, predictions, labels):\n",
    "  \"\"\"Shows up to 9 images in a 3x3 grid, titled with the predicted class.\n",
    "\n",
    "  Wrong predictions get a red title.\n",
    "\n",
    "  Args:\n",
    "    images: the images to display; only the first 9 are used.\n",
    "    predictions: per-image class scores; the predicted class is the argmax over the last axis.\n",
    "    labels: the expected class index of each image.\n",
    "  \"\"\"\n",
    "  plt.figure(figsize=(13,13))\n",
    "  predicted_classes = np.argmax(predictions, axis=-1)\n",
    "  position = 331 # 3 rows, 3 columns, first cell\n",
    "  for index, image in enumerate(images[:9]):\n",
    "    title, correct = title_from_label_and_target(predicted_classes[index], labels[index])\n",
    "    position = display_one_flower(image, title, position, not correct)\n",
    "\n",
    "  plt.subplots_adjust(wspace=0.1, hspace=0.1)\n",
    "  plt.show()\n",
    "  \n",
    "def display_training_curves(training, validation, title, subplot):\n",
    "  \"\"\"Plots a training curve and a validation curve in one subplot.\n",
    "\n",
    "  Args:\n",
    "    training: sequence of per-epoch values for the training set.\n",
    "    validation: sequence of per-epoch values for the validation set.\n",
    "    title: name of the plotted metric, used in the plot title and the y axis label.\n",
    "    subplot: subplot position passed to plt.subplot (e.g. 211). A position\n",
    "      ending in 1 starts a new figure.\n",
    "  \"\"\"\n",
    "  if subplot%10==1: # set up the figure on the first call\n",
    "    # plt.figure rather than plt.subplots: the latter also creates a full-size Axes,\n",
    "    # which recent matplotlib versions leave in place behind the subplots added below\n",
    "    plt.figure(figsize=(10,10), facecolor='#F0F0F0')\n",
    "  ax = plt.subplot(subplot)\n",
    "  ax.set_facecolor('#F8F8F8')\n",
    "  ax.plot(training)\n",
    "  ax.plot(validation)\n",
    "  ax.set_title('model '+ title)\n",
    "  ax.set_ylabel(title)\n",
    "  ax.set_xlabel('epoch')\n",
    "  ax.legend(['train', 'valid.'])"
   ],
   "execution_count": 3,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kvPXiovhi3ZZ"
   },
   "source": [
    "## Read images and labels from TFRecords"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "LtAVr-4CP1rp"
   },
   "source": [
    "def read_tfrecord(example):\n",
    "    \"\"\"Parses one serialized example into an (image, class_label) pair.\n",
    "\n",
    "    Args:\n",
    "      example: a serialized example holding a JPEG-encoded \"image\" and an int64 \"class\".\n",
    "\n",
    "    Returns:\n",
    "      image: float32 tensor of shape [*IMAGE_SIZE, 3] with values in the [0, 1] range.\n",
    "      class_label: the int64 class stored in the example.\n",
    "    \"\"\"\n",
    "    feature_spec = {\n",
    "        \"image\": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring\n",
    "        \"class\": tf.io.FixedLenFeature([], tf.int64),  # shape [] means scalar\n",
    "    }\n",
    "    parsed = tf.io.parse_single_example(example, feature_spec)\n",
    "    pixels = tf.io.decode_jpeg(parsed['image'], channels=3)\n",
    "    pixels = tf.cast(pixels, tf.float32) / 255.0  # scale the pixel values to floats in [0, 1]\n",
    "    # explicit size will be needed for TPU\n",
    "    image = tf.reshape(pixels, [*IMAGE_SIZE, 3])\n",
    "    return image, parsed['class']\n",
    "\n",
    "def load_dataset(filenames):\n",
    "  \"\"\"Builds a dataset of decoded (image, class_label) pairs from TFRecord files.\n",
    "\n",
    "  Several files are read in parallel and element order is not preserved\n",
    "  (experimental_deterministic = False), which allows order-altering optimizations.\n",
    "\n",
    "  Args:\n",
    "    filenames: list of TFRecord file paths.\n",
    "\n",
    "  Returns:\n",
    "    An unbatched dataset whose elements are produced by read_tfrecord.\n",
    "  \"\"\"\n",
    "  ignore_order = tf.data.Options()\n",
    "  ignore_order.experimental_deterministic = False\n",
    "\n",
    "  records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)\n",
    "  unordered_records = records.with_options(ignore_order)\n",
    "  return unordered_records.map(read_tfrecord, num_parallel_calls=AUTOTUNE)"
   ],
   "execution_count": 4,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "xb-b4PRz-V6O"
   },
   "source": [
    "# sanity check: show a few decoded training images with their class names\n",
    "display_9_images_from_dataset(load_dataset(training_filenames))"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "yCDq52zgRHtH"
   },
   "source": [
    "## Training and validation datasets"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "N5Y9XywVRHtK"
   },
   "source": [
    "def get_batched_dataset(filenames, train=False):\n",
    "  \"\"\"Builds the cached, batched and prefetched dataset fed to Keras.\n",
    "\n",
    "  Best practices for Keras: the training dataset is repeated, then batched;\n",
    "  the evaluation dataset is not repeated.\n",
    "  Source: Dataset performance guide: https://www.tensorflow.org/guide/performance/datasets\n",
    "\n",
    "  Args:\n",
    "    filenames: list of TFRecord file paths.\n",
    "    train: if True, the dataset is repeated before batching.\n",
    "\n",
    "  Returns:\n",
    "    A dataset of batches of BATCH_SIZE (image, class_label) pairs.\n",
    "  \"\"\"\n",
    "  pipeline = load_dataset(filenames).cache() # This dataset fits in RAM\n",
    "  if train:\n",
    "    pipeline = pipeline.repeat()\n",
    "  # should shuffle too but this dataset was well shuffled on disk already\n",
    "  batches = pipeline.batch(BATCH_SIZE)\n",
    "  # prefetch next batch while training (autotune prefetch buffer size)\n",
    "  return batches.prefetch(AUTOTUNE)\n",
    "\n",
    "# instantiate the datasets: the training dataset is repeated (train=True),\n",
    "# the validation dataset is not\n",
    "training_dataset = get_batched_dataset(training_filenames, train=True)\n",
    "validation_dataset = get_batched_dataset(validation_filenames, train=False)"
   ],
   "execution_count": 6,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ALtRUlxhw8Vt"
   },
   "source": [
    "## Model [WORK REQUIRED]\n",
    "1. Start with a dummy single-layer model using one dense layer:\n",
    " * Use  a [`tf.keras.Sequential`](https://www.tensorflow.org/api_docs/python/tf/keras/models/Sequential) model. The constructor takes a list of layers.\n",
    " * First, [`Flatten()`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Flatten) the pixel values of the input image to a 1D vector so that a dense layer can consume it:<br/>\n",
    " `tf.keras.layers.Flatten(input_shape=[*IMAGE_SIZE, 3])  # the first layer must also specify input shape`\n",
    " * Add a single [`tf.keras.layers.Dense`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense) layer with softmax activation and the correct number of units (hint: 5 classes of flowers):<br/>\n",
    "  `tf.keras.layers.Dense(5, activation='softmax')`\n",
    " * add the last bits and pieces with [model.compile()](https://www.tensorflow.org/api_docs/python/tf/keras/models/Model#compile). For a classifier, you need `'sparse_categorical_crossentropy'` loss, `'accuracy'` in metrics and you can use the `'adam'` optimizer.\n",
    "\n",
    " **==>Train this model: not very good... but all the plumbing is in place.**\n",
    "\n",
    "1. Instead of trying to figure out a better architecture, we will adapt a pretrained model to our data. Please remove all your layers to restart from scratch.\n",
    " * Instantiate a pre-trained model from `tf.keras.applications.*`\n",
    " You do not need its final softmax layer (`include_top=False`) because you will be adding your own. This code is already written in the cell below.<br/>\n",
    " * Use `pretrained_model` as your first \"layer\" in your Sequential model.\n",
    " * Follow with  [`tf.keras.layers.Flatten()`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Flatten) or [`tf.keras.layers.GlobalAveragePooling2D()`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/GlobalAveragePooling2D) to turn the data from the pretrained model into a flat 1D vector.\n",
    " * Add your [`tf.keras.layers.Dense`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense) layer with softmax activation and the correct number of units (hint: 5 classes of flowers).\n",
    "\n",
    " **==>Train the model: you should be able to reach above 75% accuracy by training for 10 epochs**\n",
    "\n",
    "1. You can try adding a second dense layer. Use 'relu' activation on all dense layers but the last one which must be 'softmax'. An additional layer adds trainable weights. It is unlikely to do much good here though, because our dataset is too small.\n",
    "\n",
    "This technique is called \"transfer learning\". The pretrained model has been trained on a different dataset but its layers have still learned to recognize bits and pieces of images that can be useful for flowers. You are retraining the last layer only, the pretrained weights are frozen. With far fewer weights to adjust, it works with less data."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "XLJNVGwHUDy1"
   },
   "source": [
    "# pretrained feature extractor, instantiated without its final classification layer (include_top=False)\n",
    "pretrained_model = tf.keras.applications.MobileNetV2(input_shape=[*IMAGE_SIZE, 3], include_top=False)\n",
    "# alternative pretrained models to experiment with:\n",
    "#pretrained_model = tf.keras.applications.VGG16(weights='imagenet', include_top=False ,input_shape=[*IMAGE_SIZE, 3])\n",
    "#pretrained_model = tf.keras.applications.ResNet50(weights='imagenet', include_top=False, input_shape=[*IMAGE_SIZE, 3])\n",
    "#pretrained_model = tf.keras.applications.MobileNet(weights='imagenet', include_top=False, input_shape=[*IMAGE_SIZE, 3])\n",
    "# freeze the pretrained weights: only the layers added in the Sequential model below are trained\n",
    "pretrained_model.trainable = False\n",
    "\n",
    "# exercise: list the layers of the model here (see the instructions in the section above)\n",
    "model = tf.keras.Sequential([\n",
    "  #\n",
    "  # YOUR CODE HERE\n",
    "  #\n",
    "])\n",
    "\n",
    "# exercise: pass the optimizer, loss and metrics here (see the instructions in the section above)\n",
    "model.compile(\n",
    "  #\n",
    "  # YOUR CODE HERE\n",
    "  # \n",
    ")\n",
    "\n",
    "model.summary()"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dMfenMQcxAAb"
   },
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "M-ID7vP5mIKs"
   },
   "source": [
    "# the training dataset is repeated, so steps_per_epoch defines the length of one epoch;\n",
    "# the validation dataset is not repeated\n",
    "history = model.fit(training_dataset, steps_per_epoch=steps_per_epoch, epochs=EPOCHS,\n",
    "                    validation_data=validation_dataset)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "VngeUBIdyJ1T"
   },
   "source": [
    "# list the recorded metrics, then plot accuracy (subplot 211) and loss (subplot 212)\n",
    "# for both the training and the validation data\n",
    "print(history.history.keys())\n",
    "display_training_curves(history.history['accuracy'], history.history['val_accuracy'], 'accuracy', 211)\n",
    "display_training_curves(history.history['loss'], history.history['val_loss'], 'loss', 212)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MKFMWzh0Yxsq"
   },
   "source": [
    "## Predictions"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "yMEsR851VDZb"
   },
   "source": [
    "# random input: execute multiple times to change results\n",
    "flowers, labels = dataset_to_numpy_util(load_dataset(validation_filenames).skip(np.random.randint(300)), 9)\n",
    "\n",
    "predictions = model.predict(flowers, steps=1)\n",
    "print(np.array(CLASSES)[np.argmax(predictions, axis=-1)].tolist())"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "qzCCDL1CZFx6"
   },
   "source": [
    "# show the images with their predicted class; wrong predictions are titled in red\n",
    "display_9_images_with_predictions(flowers, predictions, labels)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "W4pcQx3k4dUt"
   },
   "source": [
    "## License\n",
    "\n",
    "\n",
    "\n",
    "---\n",
    "\n",
    "\n",
    "author: Martin Gorner<br>\n",
    "twitter: @martin_gorner\n",
    "\n",
    "\n",
    "---\n",
    "\n",
    "\n",
    "Copyright 2021 Google LLC\n",
    "\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "you may not use this file except in compliance with the License.\n",
    "You may obtain a copy of the License at\n",
    "\n",
    "    http://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "Unless required by applicable law or agreed to in writing, software\n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "See the License for the specific language governing permissions and\n",
    "limitations under the License.\n",
    "\n",
    "\n",
    "---\n",
    "\n",
    "\n",
    "This is not an official Google product but sample code provided for an educational purpose\n"
   ]
  }
 ]
}